#!/usr/bin/env python3

"""
This file generates faebryk/src/faebryk/library/__init__.py
"""

import logging
import re
from graphlib import TopologicalSorter
from itertools import groupby
from pathlib import Path
from typing import Iterable

logger = logging.getLogger(__name__)

# Three `.parent` steps up from this script's own path. Presumably the script
# lives in tools/library/ (as the generated header says) — verify if moved.
REPO_ROOT = Path(__file__).parent.parent.parent
# Directory scanned for library modules; each <Name>.py is assumed to define a
# class named <Name> (see main()).
LIBRARY_DIR = REPO_ROOT / "src" / "faebryk" / "library"
# Generated output file; main() overwrites it completely on every run.
OUT = LIBRARY_DIR / "_F.py"


def try_(stmt: str, exc: str | type[Exception] | Iterable[type[Exception]]):
    if isinstance(exc, type):
        exc = exc.__name__
    if not isinstance(exc, str):
        exc = f"({', '.join(e.__name__ for e in exc)})"

    return (
        f"try:\n    {stmt}\nexcept {exc} as e:\n    print('{stmt.split(' ')[-1]}', e)"
    )


def topo_sort(modules_out: dict[str, tuple[Path, str]]):
    def find_deps(module_path: Path) -> set[str]:
        f = module_path.read_text(encoding="utf-8")
        p = re.compile(r"[^a-zA-Z_0-9]F\.([a-zA-Z_][a-zA-Z_0-9]*)")
        return set(p.findall(f))

    if True:
        SRC_DIR = LIBRARY_DIR.parent.parent
        all_modules = [
            (p.stem, p) for p in SRC_DIR.rglob("*.py") if not p.stem.startswith("_")
        ]
    else:
        all_modules = [
            (module_name, module_path)
            for module_name, (module_path, _) in modules_out.items()
        ]

    topo_graph = [
        (module_name, find_deps(module_path))
        for module_name, module_path in all_modules
    ]

    # handles name collisions (of non-lib modules)
    topo_grouped = groupby(
        sorted(topo_graph, key=lambda item: item[0]), lambda item: item[0]
    )
    topo_merged = {
        name: {x for xs in group for x in xs[1]} for name, group in topo_grouped
    }

    # sort for deterministic order
    topo_graph = {
        k: sorted(v) for k, v in sorted(topo_merged.items(), key=lambda item: item[0])
    }
    order = list(TopologicalSorter(topo_graph).static_order())

    # TEST
    seen = set()
    for m in order:
        if m not in topo_graph:
            continue
        for sub in topo_graph[m]:
            if sub not in seen and sub in topo_graph:
                raise Exception(f"Collision: {sub} after {m}")
        seen.add(m)

    return [
        (module_name, modules_out[module_name][1])
        for module_name in order
        if module_name in modules_out
    ]


def main():
    """
    Regenerate ``OUT`` with one import line per public library module.

    Each line has the form ``from faebryk.library.<module> import <module>``.
    Lines are ordered by ``topo_sort`` so dependencies are imported first.

    :raises FileNotFoundError: if ``LIBRARY_DIR`` is not an existing directory.
    """
    # Explicit check instead of ``assert``: assertions are stripped under -O.
    if not LIBRARY_DIR.is_dir():
        raise FileNotFoundError(f"Library directory not found: {LIBRARY_DIR}")

    logger.info(f"Scanning {LIBRARY_DIR} for modules")

    # Files whose name starts with "_" (including the generated _F.py) are skipped.
    module_files = [p for p in LIBRARY_DIR.glob("*.py") if not p.name.startswith("_")]

    logger.info(f"Found {len(module_files)} modules")

    # Convention: module <Name>.py defines a class named <Name>.
    modules_out: dict[str, tuple[Path, str]] = {
        module_path.stem: (module_path, module_path.stem)
        for module_path in module_files
    }

    modules_ordered = topo_sort(modules_out)

    logger.info(f"Found {len(modules_out)} classes")

    header = (
        "# This file is part of the faebryk project\n"
        "# SPDX-License-Identifier: MIT\n"
        "\n"
        '"""\n'
        "This file is autogenerated by tools/library/gen_F.py\n"
        "This is the __init__.py file of the library\n"
        "All modules are in ./<module>.py with name class <module>\n"
        "Export all <module> classes here\n"
        "Do it programmatically instead of specializing each manually\n"
        "This way we can add new modules without changing this file\n"
        '"""\n'
        "\n"
        "# Disable ruff warning for whole block\n"
        "# flake8: noqa: F401\n"
        "# flake8: noqa: I001\n"
        "# flake8: noqa: E501\n"
        "\n"
    )
    # To debug failing imports, each line can be wrapped with
    # try_(line, (AttributeError,)).
    import_lines = "\n".join(
        f"from faebryk.library.{module} import {class_}"
        for module, class_ in modules_ordered
    )

    OUT.write_text(header + import_lines + "\n", encoding="utf-8")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
